
fix(server): prompt-cache bleed fixes — MambaCache gate + ndim guard + spec-decode ordering#85

Merged
solderzzc merged 3 commits into SharpAI:main from ericjlake:perf/combined on Apr 26, 2026

Conversation

@ericjlake ericjlake commented Apr 26, 2026

Closes (part of) #84.

This PR is one of two paired patches landing the work from #84. The companion PR in mlx-swift-lm adds the needsMoeFlush gate that produces the headline 18 → 63 tok/s speedup on full-RAM Qwen3-A3B; this PR fixes four correctness bugs in SwiftLM's prompt-cache path that were silently regressing throughput and quality on chat-template replays. Bonus: a small README perf table addition (per @solderzzc's request in #84).

What changed

Sources/SwiftLM/Server.swift — four prompt-cache fixes

1. MambaCache safety gate in save()

Adds an early return when any layer in the cache is a MambaCache. Mamba's recurrent state can't be partially trimmed (unlike attention's offset-decrement), so the cache cannot be safely saved/restored for hybrid Attention+Mamba models. Without this guard, cache.trim(N) on a MambaCache layer hits an unrelated assertion path. Same direction as upstream 5553bf5 (disable prompt cache for MambaCache hybrid models) — this is the explicit guard at the save site.
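In sketch form (the entry-point name and type stubs below are illustrative stand-ins, not the exact Server.swift signatures):

```swift
// Types stubbed for illustration; the real ones live in SwiftLM's cache layer.
protocol KVCache: AnyObject {}
final class MambaCache: KVCache {}    // recurrent state, no T axis to trim
final class KVCacheSimple: KVCache {} // attention state, [B, H, T, D]

func savePromptCache(tokens: [Int], layers: [KVCache]) {
    // A recurrent state has no offset to decrement, so a saved hybrid
    // cache could never be partially restored. Fail closed at the save site.
    guard !layers.contains(where: { $0 is MambaCache }) else { return }
    // ... attention-only save path continues ...
}
```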

2. KVCacheSimple T-dim slice in save()

For attention KVCacheSimple layers, the state tensor is [B, H, T, D] with a pre-allocated T that can exceed the actual prompt length P. If we store the full over-sized buffer, restore()'s trim(cached.tokens.count - matchLen) still leaves T - P slots of garbage beyond the valid prefix. We now slice T down to P at save time, so cached.tokens.count equals the cached state's T.
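A rough sketch of the slice, using mlx-swift's range subscripting; the function shape is an assumption, only ndim and dim(2) come from the actual diff:

```swift
import MLX

// Clamp a [B, H, T, D] attention state to the prompt length P so the
// saved T axis matches cached.tokens.count exactly.
func sliceToPromptLength(_ states: [MLXArray], promptLength p: Int) -> [MLXArray] {
    states.map { (s: MLXArray) -> MLXArray in
        // Leave non-attention or already-tight buffers untouched.
        guard s.ndim >= 3, s.dim(2) > p else { return s }
        return s[0..., 0..., ..<p, 0...]
    }
}
```

After this, the stored T axis and cached.tokens.count agree, so restore()'s trim arithmetic can never expose uninitialized slots past the valid prefix.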

3. ndim >= 3 guard inside minCachedSeqLen scan

When no caches were populated yet, the min-search returned 0 and the restore path short-circuited the prefill, producing a "1-token-then-EOS" pattern on warm replays. The guard fails closed: if a state tensor doesn't have the expected rank, it is skipped in the min calculation rather than silently returning a degenerate result.
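A minimal sketch of the fail-closed scan (function shape assumed):

```swift
import MLX

// Fail-closed min-length scan: rank-<3 states (recurrent layers, or
// nothing populated yet) are skipped instead of contributing a bogus 0.
func minCachedSeqLen(_ states: [MLXArray]) -> Int? {
    var minLen: Int?
    for s in states {
        guard s.ndim >= 3 else { continue } // no T axis at dim 2: skip
        let t = s.dim(2)
        minLen = min(minLen ?? t, t)
    }
    return minLen // nil means "nothing usable"; caller treats it as a cache miss
}
```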

4. Spec-decode short-circuit ordering in process()

The hasDraftModel short-circuit now runs before the prompt-cache restore decision. Reason: a partial-match cache restore corrupts the draft model's KV state, since draft and main cycle tokens in lock-step. Better to pay the full prefill than emit garbage. This is independent of #72's auto-cap (which addresses I/O fan-out for --stream-experts + --draft-model); this change is about correctness on the in-RAM spec-decode path.
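Schematically, with stand-in names for the real process() collaborators; the shape of the branch is what matters:

```swift
// Spec-decode check runs BEFORE any prompt-cache restore decision.
func decidePrefillPath(tokens: [Int],
                       hasDraftModel: Bool,
                       restoreMatchLen: ([Int]) -> Int?, // matched prefix length
                       prefill: ([Int]) -> Void) {
    if hasDraftModel {
        // Draft and main cycle tokens in lock-step; a partial-match
        // restore would desync the draft's KV state. Full prefill is
        // slower but correct.
        prefill(tokens)
        return
    }
    if let matchLen = restoreMatchLen(tokens) {
        prefill(Array(tokens[matchLen...])) // warm path: only the unmatched tail
    } else {
        prefill(tokens)                     // cold path: full prefill
    }
}
```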

README.md — Performance subsection for full-RAM Qwen3-A3B

Adds a new subsection to the Performance section with reproducible numbers across Vanilla and DFlash-spec-decode configurations on M1 Ultra 64 GB, mirroring the existing M5 Pro Gemma-4 table style. Per @solderzzc's suggestion in #84 — happy to iterate on layout/content.

Test plan

  • Tested with Qwen3.6-35B-A3B-UD-MLX-4bit on M1 Ultra 64 GB — ≈3.5× steady-state generation speedup (18 → 63 tok/s) with companion mlx-swift-lm PR applied.
  • No regressions in default-flag inference path (warm-cache replays return correct token counts, no 1-token-then-EOS pattern).
  • Spec-decode short-circuit verified to bypass cache lookup when --draft-model is set.

Companion PR

SharpAI/mlx-swift-lm PR with the needsMoeFlush gate: SharpAI/mlx-swift-lm#34

Eric and others added 2 commits April 22, 2026 04:05
Three fixes, now riding on upstream 116ee91:

1. save(): slice KVCacheSimple state T-dim down to P=tokens.count so the
   cached states' T matches cached.tokens.count. Prevents the over-allocated
   prefill buffer from carrying uninitialized tokens past the valid prefix.

2. restore(): gate out recurrent-state layers (MambaCache and friends) up
   front. Their state is 2-D with no T dimension, so the dim(2) read in the
   pre-flight check would crash; also there's no trim(excess) operator for
   a recurrent hidden state — we can't partial-restore one safely. Guard
   with ndim>=3 inside the min-length scan too for belt-and-suspenders.

3. handleChatCompletion(): reorder the decision branch so speculative
   decoding is checked BEFORE the prompt cache restore. A cache-hit rollback
   corrupts the draft model's KV state (draft and main cycle tokens in
   lock-step), so when draftModelRef is set we bypass the cache entirely
   and pay the full prefill. Partial-match restores stay available on the
   non-spec path where they still pay off.
Adds a new Performance subsection covering full-RAM Qwen3.6-35B-A3B-UD-MLX-4bit
inference on M1 Ultra 64 GB:
- Vanilla full-GPU (62 tok/s) — post needsMoeFlush gate (SwiftLM SharpAI#84)
- DFlash spec decode with z-lab/Qwen3.6-35B-A3B-DFlash (+13% medium/long,
  -15% short due to block overhead, finish_reason behavior changes)

Includes 19→62 tok/s before/after reference for the gate fix.
@solderzzc
Member

Hey Eric — this has merge conflicts against current main because the DFlash integration (PR #78 from @0xClandestine) landed after you forked. The conflicts are in the Server.swift decision branch — we added a DFlash early-return block and a skipPromptCache guard (for kvBits) that your diff doesn't know about.

We're going to merge main into your branch to resolve this on our end, since we wrote those changes and can resolve the conflicts accurately. Your original commits will stay intact — just an additional merge commit on top.

Also — the companion PR (mlx-swift-lm#34) is merged! ✅

Merges ericjlake's prompt-cache fixes from PR SharpAI#85, resolving conflicts
with the DFlash integration (PR SharpAI#78).

Changes from ericjlake:
- MambaCache safety gate + KVCacheSimple T-dim slice in save()
- ndim >= 3 guard in minCachedSeqLen scan
- Spec-decode short-circuit ordering (check before cache restore)
- README: Qwen3-A3B full-RAM perf table (M1 Ultra 64 GB)

Conflict resolution:
- README.md: kept both Qwen3-A3B and DeepSeek-V4 perf tables
- Server.swift save(): kept existing MambaCache early return + new T-dim slice
- Server.swift decision branch: combined spec-decode-first + skipPromptCache (kvBits)

Closes SharpAI#84.
Co-authored-by: Eric Lake <ericjlake@users.noreply.github.com>
@solderzzc
Member

Update: we tried to push the conflict resolution directly to your perf/combined branch but hit a permissions block ("Allow edits from maintainers" appears to be disabled on this PR).

Instead, we sent the resolution as a PR to your fork: ericjlake#1

Once you merge that, this PR will be conflict-free and we can land it here. All your original commits are preserved — just one additional merge commit on top.

@solderzzc solderzzc reopened this Apr 26, 2026
@solderzzc solderzzc merged commit 7df2170 into SharpAI:main Apr 26, 2026
26 checks passed
solderzzc pushed a commit that referenced this pull request Apr 26, 2026
Follow-up to #85 (just merged). Subsequent benchmarking discovered the
70 tok/s DFlash medium/long numbers in that PR were ALWAYS degenerate
output ("and and and...", "**UMA** **UMA**...") — high acceptance because
draft and target both committed to the same locked-in token every block.
Root cause: DFlash uses argMax greedy regardless of request temperature.
Vanilla samples stochastically at temp=0.6 which breaks ties; DFlash has
no tie-breaker and locks into low-entropy attractors.

Mitigation experiments (rep-penalty 1.1, 1.3) only partially help: 1.1
is too weak to dislodge hard attractors (1/5 prompts clean), 1.3 fixes
attractors but acceptance crashes 80%->18-46% so DFlash becomes net-
negative below vanilla. Proper fix is stochastic posterior sampling with
rejection-based accept (Leviathan/Chen), tracked at z-lab/dflash#91.

Replaces the misleading row with a clear warning so users do not adopt
a degenerate codepath as the recommended config.

See z-lab/dflash#91 (issuecomment 4322584783) for the full diagnosis.
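For reference, the Leviathan/Chen accept rule the commit points to works roughly like this; a generic sketch, not DFlash's actual code:

```swift
import Foundation

// Accept draft token x with probability min(1, p[x]/q[x]); on rejection,
// resample from the normalized positive residual max(0, p - q). The
// committed stream then follows the target distribution p exactly, so
// greedy tie-lock attractors can't form.
func acceptOrResample(draftToken x: Int, target p: [Double], draft q: [Double]) -> Int {
    if q[x] > 0, Double.random(in: 0..<1) < min(1.0, p[x] / q[x]) {
        return x // accepted: the draft token stands
    }
    // Rejected: sample from max(0, p - q), renormalized.
    let residual = zip(p, q).map { max(0, $0.0 - $0.1) }
    let z = residual.reduce(0, +)
    guard z > 0 else { return x } // degenerate p == q edge case
    var r = Double.random(in: 0 ..< z)
    for (token, w) in residual.enumerated() {
        r -= w
        if r <= 0 { return token }
    }
    return residual.count - 1 // floating-point edge
}
```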
solderzzc added a commit that referenced this pull request Apr 26, 2026
docs(README): remove degenerate DFlash perf row from #85 perf table
solderzzc added a commit that referenced this pull request Apr 27, 2026
test: PromptCache regression tests — model-free coverage for PR #85
ajunlonglive pushed a commit to ajunlonglive/SwiftLM that referenced this pull request May 1, 2026
17 unit tests protecting the prompt-cache save/restore contract:

Group 1 — save() guards:
  - MambaCache gate (skip save for hybrid models)
  - T-dim slice to P (prevent garbage beyond valid prefix)
  - Small T-dim preservation (no-op when T <= P)
  - Pure KVCacheSimple smoke test

Group 2 — restore() guards:
  - MambaCache rejection in target cache
  - Recurrent layer detection (ArraysCache bail)
  - Full/partial/no match paths
  - Empty cache graceful miss
  - Sliding window trim safety

Group 3 — Decision branch ordering:
  - skipPromptCache (multimodal, kv_bits)
  - Spec-decode checked before cache restore
  - Cache hit used when no draft model

Group 4 — Stats tracking

Zero model downloads, 0.04s runtime. Uses synthetic KVCache
instances with tiny MLXArray tensors ([1,2,T,4]).
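A representative Group 1 test might look like this; the PromptCache and KVCacheSimple calls are assumed from the group descriptions above, not copied from the suite:

```swift
import XCTest
import MLX
@testable import SwiftLM

final class PromptCacheSaveTests: XCTestCase {
    // Shapes mirror the suite's tiny [1, 2, T, 4] tensors.
    func testSaveSlicesTDimDownToPromptLength() {
        let promptLen = 5
        let layer = KVCacheSimple()
        // Over-allocate the T axis (8) past the valid prompt length (5).
        _ = layer.update(keys: zeros([1, 2, 8, 4]),
                         values: zeros([1, 2, 8, 4]))
        let cache = PromptCache() // hypothetical API, per the notes above
        cache.save(tokens: Array(0 ..< promptLen), layers: [layer])
        // Contract under test: the saved state's T axis equals tokens.count.
        XCTAssertEqual(cache.savedStates.first?.dim(2), promptLen)
    }
}
```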
